import json
from sklearn.model_selection import train_test_split, TimeSeriesSplit
import datetime
import pandas as pd
import numpy as np
import sys
from sklearn.metrics._scorer import _BaseScorer
from sklearn.metrics import get_scorer
import inspect
import itertools
import sklearn
import collections
from collections.abc import Callable
import random
from dateutil.relativedelta import relativedelta
from itertools import product

# Seed the standard-library `random` generator once, at import time.
# This affects only the `random` module's global generator.
random.seed(42)

# File that Debug._to_file appends log lines to (a relative path).
LOG_FILENAME = 'log.txt'

def print_time(time):
    """Format ``time`` as a zero-padded ``HH:MM:SS`` string.

    Despite the name, nothing is printed: the formatted text is returned.
    ``time`` can be any object exposing ``strftime`` (for example a
    ``datetime.datetime`` or ``datetime.time``).
    """
    clock_format = "%H:%M:%S"
    return time.strftime(clock_format)


def read_json(filepath):
    """Load and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 explicitly, because the default encoding
    of ``open`` depends on the platform locale and JSON files are UTF-8.

    Raises whatever ``open`` raises for an unreadable path and
    ``json.JSONDecodeError`` for malformed content.
    """
    with open(filepath, encoding="utf-8") as f:
        return json.load(f)


def get_class(module,classname, **kwargs):
    """Instantiate ``classname`` from the module named ``module``.

    Args:
        module: Importable module name as a string (e.g. ``"collections"``).
        classname: Name of an attribute of that module to call.
        **kwargs: Keyword arguments forwarded to the constructor.

    Returns:
        The object produced by ``module.classname(**kwargs)``.

    Raises:
        ModuleNotFoundError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``classname``.
    """
    # Local import keeps this block self-contained.
    import importlib

    # import_module returns the already-loaded module when there is one and
    # imports it otherwise, so the caller no longer has to import it first.
    module_obj = importlib.import_module(module)
    class_ = getattr(module_obj, classname)
    return class_(**kwargs)

def set_attributes(obj, **kwargs):
    """Set each keyword argument as an attribute on ``obj`` and return ``obj``.

    The object is modified in place; the return value is the same object,
    which allows the call to be used inline.
    """
    for attr_name in kwargs:
        setattr(obj, attr_name, kwargs[attr_name])
    return obj

def build_scorer(name,**kwargs) ->_BaseScorer:
    """Fetch the sklearn scorer for ``name`` and override attributes on it.

    ``name`` is handed to ``sklearn.metrics.get_scorer``; every keyword
    argument is then set as an attribute on the returned scorer object via
    ``set_attributes``. The modified scorer is returned.
    """
    return set_attributes(get_scorer(name), **kwargs)

        
class TestTrainSplit:
    """Splitter object that calls ``train_test_split`` with a stored ``test_size``."""

    def __init__(self, test_size):
        # Forwarded unchanged as ``test_size`` on every ``split`` call.
        self.test_size = test_size

    def split(self, X, y):
        """Return ``train_test_split(X, y, test_size=self.test_size)``."""
        split_options = {"test_size": self.test_size}
        return train_test_split(X, y, **split_options)


def pad_jagged_array(M,v):
    """Pad a jagged collection of rows into a square list of lists.

    The side length ``n`` is the larger of the number of rows and the
    longest row. Short rows are right-padded with ``v`` and, when there are
    fewer than ``n`` rows, extra rows consisting only of ``v`` are appended.

    Args:
        M: List of rows; each row must expose ``.values.tolist()``
           (e.g. a pandas Series).
        v: Fill value.

    Returns:
        A list of ``n`` lists, each of length ``n``. ``[]`` for empty ``M``.
    """
    n = max([len(M)] + [len(row) for row in M])
    out = []
    for row in M:
        values = row.values.tolist()
        out.append(values + [v] * (n - len(values)))
    # Filler rows are built directly from ``v``: they are not pandas objects,
    # so routing them through ``.values`` would raise AttributeError.
    out.extend([v] * n for _ in range(n - len(M)))
    return out

class Timer:
    """Minimal stopwatch: remembers a start instant and reports the time elapsed since."""

    def __init__(self):
        # No start instant is recorded until ``start`` is called.
        self.time = None

    def start(self):
        """Record the current datetime as the start instant and return ``self``."""
        self.time = datetime.datetime.now()
        return self

    def get_time(self):
        """Return the ``timedelta`` between now and the recorded start instant.

        ``start`` must have been called first; otherwise the subtraction
        against ``None`` raises ``TypeError``.
        """
        started_at = self.time
        return datetime.datetime.now() - started_at


class Debug(object):
    """
    A Class which provides methods for logging and timing. 
    Instead of printing statements, this class helps provide 
    consistent details such as which part of the code print 
    statements come from and the execution time of parts of code.
    """
    timestamp = None
    print_to_file = False
    """Set to true and the log statements will be appended to a file in the same directory as the running code. 
    A new file will be created if one does not already exist."""
    muted = False
    """Set muted to True if you want all log statements to do nothing when called"""
    muted_channels = []

    @staticmethod
    def _to_file(string):
        file = open(LOG_FILENAME, mode="a+")
        file.write(string + "\n")
        file.close()
    

    def log(message,channel=None):
        message = str(message)
        ## Do nothing if muted == True
        if Debug.muted == True:
            return None

        ## You can also add strings to the muted channel list and mute only Debug's with those channels.
        if channel in Debug.muted_channels:
            return None

        frame = inspect.stack()[1]
        location = (frame.filename[-20:] + " "+str(frame.lineno))
        output = (
            "LOG: "+
            location + " "+
            str(datetime.datetime.now()) + " "+
            message
            )
        if Debug.print_to_file == True:
            Debug._to_file(output)
        print(output)

    def start_timer():
        """Used in conjuncture with stop_timer. records the datetime when this function is called"""
        Debug.timestamp = datetime.datetime.now()
    
    def stop_timer(reset=True)->datetime.datetime:
        """Used in conjuncture with start_timer. Returns the time in ms since start_timer was called. Returns zero if start timer was not set correctly"""
        out = 0
        
        if(Debug.timestamp != None):
            out = datetime.datetime.now() - Debug.timestamp
            
        if reset == True:
            Debug.timestamp = None

        return out
    def get_time()->datetime.datetime:
        return Debug.timestamp


class IdGenerator:
    """Hands out ids of the form ``<date>_<n>``.

    ``stem`` is today's date as a string, evaluated once when the class is
    defined; ``suffix`` is a shared counter starting at 0.
    """
    suffix = itertools.count()
    stem = str(datetime.date.today())

    @staticmethod
    def generate_id():
        """Return the next id: the date stem, an underscore and the next counter value."""
        sequence_number = next(IdGenerator.suffix)
        return f"{IdGenerator.stem}_{sequence_number}"

## Source of the permutation-importance code below (calculate_permutation_importance):
## https://gist.githubusercontent.com/DenisVorotyntsev/6ef5a0222b71e82af8456bd2130de4d7/raw/4df32f3e1896e9eeee3face7b3eb94ca1abcbfe6/calculate_permutation_importance.py

def get_3d_features(X:list):
    """Return ``range(k)`` where ``k`` is the second dimension of the first element of ``X``.

    Only ``X[0]`` is inspected; it must expose a ``shape`` with at least two
    entries (presumably every element shares that width; confirm with callers).
    """
    return range(X[0].shape[1])

def scramble_3d_X(X:list,col:int):
    """Resample column ``col`` of every array in ``X`` in place.

    For each array the column values are redrawn with replacement (as many
    draws as the array has rows) using the fixed seed ``42 + col``, so the
    result is deterministic for a given input.

    NOTE: the arrays are modified in place and the same ``X`` object is
    returned; pass copies if the originals must stay intact.

    Args:
        X: Iterable of 2-D arrays supporting ``x[:, col]`` assignment.
        col: Positional index of the column to scramble.

    Returns:
        ``X`` itself, with the column overwritten in each array.
    """
    for x in X:
        x[:,col] = pd.Series(x[:,col]).sample(x.shape[0], replace=True, random_state=42+col).values
    return X

def calculate_permutation_importance(
    model,
    X: pd.DataFrame,
    y: pd.Series,
    scoring_function: Callable = sklearn.metrics.roc_auc_score,
    n_repeats: int = 3
    ):
    """
    Estimate the permutation importance of every feature for a fitted model.

    The model is scored once on the untouched data. Then, ``n_repeats`` times
    per feature, that feature is resampled with replacement (fixed seeds) and
    the model is scored again. A feature's importance is the mean of
    ``score_permuted - score`` over the repeats, so with a higher-is-better
    metric the most important features receive the most negative values.

    :param model: fitted model exposing ``predict_proba``
    :param X: input features; either a DataFrame (features are its columns) or
        a list of 2-D arrays (features are the column positions of the first array)
    :param y: input target
    :param scoring_function: function called as ``scoring_function(y, y_score)``,
        should output single float value
    :param n_repeats: how many times each feature is permuted
    :return: ``collections.OrderedDict`` mapping feature -> mean score delta
    """
    Debug.log("Calculating feature importances")
    seed = 42
    is_table = hasattr(X, 'columns')

    def predict_scores(data):
        # With more than one probability column, keep the second one (index 1).
        y_hat = model.predict_proba(data)
        if y_hat.shape[1] > 1:
            y_hat = y_hat[:, 1]
        return y_hat

    # baseline score on unshuffled data (higher score - better)
    score = scoring_function(y, predict_scores(X))

    features = X.columns if is_table else get_3d_features(X)
    importance = collections.OrderedDict((key, 0) for key in features)

    for n in range(n_repeats):
        for col in features:
            # Always start from a fresh copy so that earlier shuffles never leak
            # into later ones or into the caller's data.
            if is_table:
                Debug.log("shuffling table",channel='deep')
                X_temp = X.copy()
                X_temp[col] = X[col].sample(X.shape[0], replace=True, random_state=seed+n).values
            else:
                Debug.log("shuffling cube",channel='deep')
                # list.copy() would be shallow and scramble_3d_X writes into the
                # arrays in place, so each array is copied individually.
                # NOTE(review): scramble_3d_X seeds by column only, so every
                # repeat produces the same shuffle for list input.
                X_temp = scramble_3d_X([x.copy() for x in X], col)

            # score the model on the shuffled dataset
            score_permuted = scoring_function(y, predict_scores(X_temp))
            Debug.log("fe score "+str(score_permuted),channel='deep')

            # better model <-> higher score
            # lower the delta -> more important the feature
            delta_score = score_permuted - score

            importance[col] += delta_score / n_repeats
    return importance


def get_timeframe_splits(timeframe: list, max_train_size=10, test_size=1, offset=0):
    """
    Build time-ordered train/test splits over ``timeframe`` with TimeSeriesSplit.

    The number of splits is ``len(timeframe) - 1 - offset``.

    Parameters:
    - timeframe (list): Ordered time points (timestamps or indices) to split.
    - max_train_size (int, optional): Passed to TimeSeriesSplit as ``max_train_size``.
    - test_size (int, optional): Passed to TimeSeriesSplit as ``test_size``.
    - offset (int, optional): Reduces the number of splits by this amount.

    Returns:
    - splits: The iterable returned by ``TimeSeriesSplit.split(timeframe)``,
      yielding (train positions, test positions) pairs.
    """
    n_splits = len(timeframe) - 1 - offset
    splitter = TimeSeriesSplit(
        n_splits=n_splits,
        max_train_size=max_train_size,
        test_size=test_size,
    )
    return splitter.split(timeframe)

def get_splits(data, timeframe, max_train_size=None, test_size=1, offset=0):
    """
    Build (train, test) row-position splits for a dataset indexed by wave.

    ``get_timeframe_splits`` splits the positions of ``timeframe``; each
    group of positions is mapped back to the wave labels at those positions
    and then, via ``get_idx_from_waves``, to the positional row indices of
    ``data`` whose index labels belong to those waves.

    Parameters:
    - data (DataFrame or similar): Time series data; rows are selected by
      wave label with ``data.loc[waves, :]``.
    - timeframe (list): Ordered wave labels covering the time series.
    - max_train_size (int, optional): Maximum number of waves in each training set.
    - test_size (int, optional): Number of waves in each test set.
    - offset (int, optional): Reduces the number of splits by this amount.

    Returns:
    - splits (list of tuples): ``(train_idx, test_idx)`` arrays of row
      positions, one pair per split.
    """
    # Converted once; indexing with position arrays needs an ndarray.
    waves = np.array(timeframe)
    splits = []
    for train_pos, test_pos in get_timeframe_splits(
        timeframe,
        max_train_size=max_train_size,
        test_size=test_size,
        offset=offset,
    ):
        train_idx = get_idx_from_waves(data, waves[train_pos])
        test_idx = get_idx_from_waves(data, waves[test_pos])
        splits.append((train_idx, test_idx))
    return splits

def X_y_split(data,y_label):
    """Separate ``data`` into features and target.

    Returns a tuple ``(X, y)`` where ``y`` is the column ``y_label`` and
    ``X`` is ``data`` without that column. ``data`` itself is not modified.
    """
    target = data[y_label]
    features = data.drop(columns=[y_label])
    return features, target

def get_wave_date_dict(dataset):
    """Map every wave to the latest ``Wave_Date`` found among its rows.

    Expects a ``Wave_Date`` column holding ``YYYY-MM-DD`` values and
    something named ``Wave`` to group on (presumably an index level, as in
    the gesis, soep and freda datasets; confirm against callers).

    Returns a dict ``{wave: date}``.
    """
    parsed = pd.to_datetime(dataset['Wave_Date'], format="%Y-%m-%d")
    wave_dates = parsed.dt.date
    latest_per_wave = wave_dates.groupby('Wave').max()
    return latest_per_wave.to_dict()

def months_apart(date1:datetime.date,date2:datetime.date):
    """Return the number of calendar months from ``date1`` to ``date2``.

    Only year and month are compared (days are ignored); the result is
    negative when ``date2`` lies in an earlier month than ``date1``.
    """
    year_gap = date2.year - date1.year
    month_gap = date2.month - date1.month
    return 12 * year_gap + month_gap

def add_months(input_date:datetime.date,months:int):
    """Return ``input_date`` shifted by ``months`` calendar months.

    The shift is performed by adding a ``relativedelta(months=months)``.
    """
    shift = relativedelta(months=months)
    return input_date + shift

def get_waves_before(dataset,end_date):
    """Return the unique first-level index values of rows dated on or before ``end_date``.

    Rows are kept when ``Wave_Date <= end_date`` (the bound is inclusive);
    the result is a pandas Index of the distinct level-0 index labels.
    """
    on_or_before = dataset['Wave_Date'] <= end_date
    selected_rows = dataset.loc[on_or_before, :]
    return selected_rows.index.get_level_values(0).unique()

def get_equivalent_waves_up_to(
    test_dataset:pd.DataFrame,
    train_dataset:pd.DataFrame,
    test_wave:int
    ):
    """
        Given a wave in the test dataset, work out how many months after
        the earliest test wave date it falls, add that many months to the
        earliest train wave date, and return the train waves dated on or
        before the resulting cut-off.

        Returns: wave labels as a pandas Index object.
    """
    test_dates = get_wave_date_dict(test_dataset)
    test_wave_date = test_dates[test_wave]
    elapsed_months = months_apart(min(test_dates.values()), test_wave_date)

    train_dates = get_wave_date_dict(train_dataset)
    cutoff_date = add_months(min(train_dates.values()), elapsed_months)
    return get_waves_before(train_dataset, cutoff_date)

def get_train_waves_before_test_waves(
    test_dataset:pd.DataFrame,
    train_dataset:pd.DataFrame,
    test_wave:int
    ):
    """
    Look up the date of ``test_wave`` in the test dataset and return the
    train waves dated on or before it, minus the last one.

    Returns: wave labels as a pandas Index object.
    """
    cutoff_date = get_wave_date_dict(test_dataset)[test_wave]
    candidate_waves = get_waves_before(train_dataset, cutoff_date)
    # The final wave is dropped; per the original author's note this is
    # because the dates in use are wave start dates.
    return candidate_waves[:-1]
    
def get_idx_from_waves(data:pd.DataFrame,waves:list):
    """Return the positional row indices of ``data`` that belong to ``waves``.

    Rows are selected by label with ``data.loc[waves, :]``; the result is a
    numpy array of the 0-based positions whose index label is among the
    selected rows' labels.
    """
    selected_labels = data.loc[waves, :].index
    in_selection = data.index.isin(selected_labels)
    return np.arange(data.shape[0])[in_selection]

def get_train_test_idx(
    test_dataset:pd.DataFrame,
    train_dataset:pd.DataFrame,
    test_wave:int,
    mode:callable,
    max_train_size:int|None
    ):
    """
    Get training and testing row positions for one test wave.

    ``mode`` chooses the candidate training waves; if more than
    ``max_train_size`` are returned only the last ``max_train_size`` are
    kept. The waves are then mapped to row positions with
    ``get_idx_from_waves``.

    Args:
        test_dataset (pd.DataFrame): The DataFrame containing the testing dataset.
        train_dataset (pd.DataFrame): The DataFrame containing the training dataset.
        test_wave (int): The wave of the testing dataset to test on.
        mode (callable): Called as ``mode(test_dataset, train_dataset, test_wave)``;
            must return a sliceable sequence of training wave labels.
        max_train_size (int | None): Maximum number of training waves to keep,
            or None for no limit. ``0`` keeps no training waves.

    Returns:
        tuple: ``(train_idx, test_idx)`` arrays of row positions into
        ``train_dataset`` and ``test_dataset`` respectively.
    """
    train_waves = mode(
                    test_dataset,
                    train_dataset,
                    test_wave
                )
    if max_train_size is not None and len(train_waves) > max_train_size:
        # Slice from an explicit start position: ``[-max_train_size:]`` would
        # keep every wave when max_train_size is 0, because ``-0 == 0``.
        train_waves = train_waves[len(train_waves) - max_train_size:]

    train_idx = get_idx_from_waves(train_dataset,train_waves)
    test_idx = get_idx_from_waves(test_dataset,[test_wave])

    return (train_idx,test_idx)

def get_cross_splits(
    test_dataset:pd.DataFrame,
    train_dataset:pd.DataFrame,
    test_waves:list,
    mode:callable,
    max_train_size:int|None
    ):
    """
    Build one (train, test) index pair per test wave.

    Every wave in ``test_waves`` is passed, together with the other
    arguments, to ``get_train_test_idx``; the resulting tuples are collected
    in the same order as ``test_waves``.

    Args:
        test_dataset (pd.DataFrame): The DataFrame containing the testing dataset.
        train_dataset (pd.DataFrame): The DataFrame containing the training dataset.
        test_waves (list): Waves of the testing dataset to build splits for.
        mode (callable): Training-wave selector forwarded to ``get_train_test_idx``.
        max_train_size (int | None): Forwarded to ``get_train_test_idx``.

    Returns:
        list: A list of ``(train_idx, test_idx)`` tuples.
    """
    return [
        get_train_test_idx(
            test_dataset,
            train_dataset,
            wave,
            mode,
            max_train_size,
        )
        for wave in test_waves
    ]



def generate_parameter_combinations(param_dict):
    """
    Expand a parameter grid into the list of all value combinations.

    Args:
        param_dict (dict): Maps each parameter name to an iterable of candidate values.

    Returns:
        list of dict: One dict per element of the Cartesian product of the
        value lists, keyed by parameter name. An empty ``param_dict`` yields ``[{}]``.
    """
    names = list(param_dict)
    value_lists = [param_dict[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]

def csv_to_latex_table(csv_file, caption, label):
    """Render a CSV file as a LaTeX ``table*`` / ``threeparttable`` snippet.

    The CSV is read with ``pd.read_csv``. The column names form the header
    row and every data row becomes one table line; cells are joined with
    `` & `` and each line ends in ``\\\\ \\hline``. All columns are
    left-aligned. Cell text is inserted as-is, without LaTeX escaping.

    Args:
        csv_file: Path or file-like object accepted by ``pd.read_csv``.
        caption: Text placed inside ``\\caption{...}``.
        label: Text placed inside ``\\label{...}``.

    Returns:
        str: The LaTeX source of the table.
    """
    line_end = ' \\\\ \\hline\n'
    frame = pd.read_csv(csv_file)
    columns = frame.columns.tolist()

    header = ' & '.join(columns) + line_end
    # iterrows is kept on purpose: it yields each row as a Series, which
    # determines how the cell values are converted to text.
    body = ''.join(
        ' & '.join(str(cell) for cell in record.tolist()) + line_end
        for _, record in frame.iterrows()
    )
    column_spec = 'l' * len(columns)

    table_str = f"""
        \\begin{{table*}}
        \\begin{{threeparttable}}
            \\caption{{{caption}}}
            \\label{{{label}}}
            \\centering
            
            \\begin{{tabular}}{{{column_spec}}}
            \\hline
            {header}
            {body}
            \\end{{tabular}}
        \\end{{threeparttable}}
        \\end{{table*}}
    """

    return table_str
